package draw

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"github.com/google/uuid"
	"github.com/pkg/errors"
	"github.com/zeromicro/go-zero/core/logx"
	"io"
	"net/http"
	"os"
	"strconv"
	"strings"
)

// Draw is implemented by image generators that turn text into pictures.
type Draw interface {
	// Txt2Img generates images from prompt (text-to-image) and reports
	// progress on ch. The messages are implementation-defined; SdDraw sends
	// "start", then the file path of every saved image, then "stop" on success.
	Txt2Img(prompt string, ch chan string) error
}

// SdDraw implements Draw on top of a Stable Diffusion WebUI server that
// exposes the /sdapi/v1 HTTP API.
type SdDraw struct {
	// Host is the server base URL; API paths such as /sdapi/v1/txt2img are
	// appended to it.
	Host string
	// Username is sent as the HTTP basic-auth user name.
	Username string
	// Password is sent as the HTTP basic-auth password.
	Password string
}

// NewSdDraw returns an SdDraw for the WebUI at host that authenticates with
// the given basic-auth user name and password.
func NewSdDraw(host, name, password string) *SdDraw {
	sd := new(SdDraw)
	sd.Host = host
	sd.Username = name
	sd.Password = password
	return sd
}

// sdImageDir is the directory where Txt2Img stores generated images.
const sdImageDir = "/tmp/image"

// Txt2Img sends prompt to the WebUI txt2img API (POST {Host}/sdapi/v1/txt2img),
// writes every returned image as a PNG file under sdImageDir and reports
// progress on ch.
//
// Channel protocol: "start" is sent once the HTTP request has been built, then
// the path of each saved image, then "stop" after all images are saved. Errors
// returned after "start" do NOT send "stop"; callers must rely on the returned
// error in that case.
//
// prompt is either a single line, which is appended to the default positive
// prompt, or text in the WebUI "PNG info" layout, for example:
//
//	masterpiece, 1girl, classroom
//	Negative prompt: lowres, bad anatomy
//	Steps: 20, Sampler: DPM++ SDE Karras, CFG scale: 7, Seed: 42, Size: 512x768, Denoising strength: 0.45
//
// Returned errors carry user-facing (Chinese) messages; details are logged.
func (sd *SdDraw) Txt2Img(prompt string, ch chan string) error {
	// Trim a trailing slash so "http://host/" does not yield "//sdapi/...".
	endpoint := strings.TrimRight(sd.Host, "/") + "/sdapi/v1/txt2img"

	body, err := json.Marshal(parseSdPrompt(prompt))
	if err != nil {
		logx.Info("draw request payload marshal fail", err)
		return errors.New("构建绘画请求失败，请重新尝试~")
	}

	drawReq, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(string(body)))
	if err != nil {
		logx.Info("draw request client build fail", err)
		return errors.New("构建绘画请求失败，请重新尝试~")
	}
	logx.Info("draw request client build success")
	drawReq.Header.Set("Content-Type", "application/json")
	// Produces the same "Basic base64(user:pass)" header as encoding it by hand.
	drawReq.SetBasicAuth(sd.Username, sd.Password)

	ch <- "start"

	// NOTE(review): the client has no timeout, so a hung server blocks this call
	// indefinitely. Generation can legitimately take minutes, so any deadline
	// added here must be generous.
	client := &http.Client{}
	res, err := client.Do(drawReq)
	if err != nil {
		logx.Info("draw request fail", err)
		return errors.New("绘画请求失败，请重新尝试~")
	}
	defer func() { _ = res.Body.Close() }()

	resBody, err := io.ReadAll(res.Body)
	if err != nil {
		logx.Info("draw request fail", err)
		return errors.New("绘画请求响应失败，请重新尝试~")
	}

	// Reject non-2xx answers (bad credentials, validation errors, ...) before
	// trying to interpret the body as an image list.
	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		logx.Info("draw request bad status", res.StatusCode, string(resBody))
		return errors.New("绘画请求失败，请重新尝试~")
	}

	// A typed struct turns a missing or malformed "images" field into an
	// error instead of a panicking type assertion.
	var resPayload struct {
		Images []string `json:"images"`
	}
	if err := json.Unmarshal(resBody, &resPayload); err != nil {
		logx.Info("resBody", string(resBody))
		logx.Info("draw request fail", err)
		return errors.New("绘画请求响应解析失败，请重新尝试~")
	}
	if len(resPayload.Images) == 0 {
		logx.Info("draw response contains no images", string(resBody))
		return errors.New("绘画请求响应解析失败，请重新尝试~")
	}

	// MkdirAll is a no-op when the directory already exists, so no Stat is needed.
	if err := os.MkdirAll(sdImageDir, 0o755); err != nil {
		logx.Info("draw mkdir fail", err)
		return errors.New("绘画请求响应解析失败，请重新尝试~")
	}

	for _, encoded := range resPayload.Images {
		path, err := saveSdImage(encoded)
		if err != nil {
			logx.Info("draw save fail", err)
			return errors.New("绘画请求响应解析失败，请重新尝试~")
		}
		// Hand the saved file over to the consumer.
		ch <- path
	}

	ch <- "stop"
	return nil
}

// parseSdPrompt builds the txt2img payload for prompt, starting from
// getDefaultDataTXT2IMGReq:
//   - single-line prompt: appended to the default positive prompt;
//   - multi-line prompt: line 1 replaces the default positive prompt, a line
//     containing "Negative prompt:" replaces the negative prompt, and a line
//     containing "Steps:" is parsed as generation parameters. Other lines are
//     ignored.
func parseSdPrompt(prompt string) TXT2IMGReq {
	req := getDefaultDataTXT2IMGReq()
	if !strings.Contains(prompt, "\n") {
		req.Prompt += prompt
		return req
	}
	for i, line := range strings.Split(prompt, "\n") {
		// Tolerate CRLF line endings.
		line = strings.TrimSuffix(line, "\r")
		switch {
		case i == 0:
			req.Prompt = line
		case strings.Contains(line, "Negative prompt:"):
			req.NegativePrompt = strings.TrimSpace(strings.ReplaceAll(line, "Negative prompt:", ""))
		case strings.Contains(line, "Steps:"):
			applySdGenerationParams(&req, line)
		}
	}
	return req
}

// applySdGenerationParams parses one WebUI parameter line such as
// "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x768, Denoising strength: 0.45"
// into req. Only the given line is inspected. Unknown keys are ignored, and a
// value that fails to parse leaves the field unchanged instead of zeroing it.
func applySdGenerationParams(req *TXT2IMGReq, line string) {
	for _, field := range strings.Split(line, ",") {
		kv := strings.SplitN(field, ":", 2)
		if len(kv) != 2 {
			continue
		}
		key, value := strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1])
		switch key {
		case "Steps":
			if n, err := strconv.Atoi(value); err == nil {
				req.Steps = n
			}
		case "Sampler":
			req.SamplerName = value
		case "CFG scale":
			// CfgScale is an int field; fractional scales such as "7.5" fail
			// to parse and keep the current value.
			if n, err := strconv.Atoi(value); err == nil {
				req.CfgScale = n
			}
		case "Seed":
			if n, err := strconv.ParseInt(value, 10, 64); err == nil {
				req.Seed = n
			}
		case "Size":
			// "512x768" => width 512, height 768.
			dims := strings.Split(value, "x")
			if len(dims) != 2 {
				continue
			}
			w, errW := strconv.Atoi(dims[0])
			h, errH := strconv.Atoi(dims[1])
			if errW == nil && errH == nil {
				req.Width, req.Height = w, h
			}
		case "Denoising strength":
			if f, err := strconv.ParseFloat(value, 64); err == nil {
				req.DenoisingStrength = f
			}
		}
	}
}

// saveSdImage decodes one base64 image from the txt2img response and writes it
// to a new, uniquely named PNG file in sdImageDir. It returns the file path.
func saveSdImage(encoded string) (string, error) {
	// Accept both raw base64 and data URIs ("data:image/png;base64,<payload>").
	// The payload follows the last comma; the base64 alphabet has no commas.
	if i := strings.LastIndex(encoded, ","); i >= 0 {
		encoded = encoded[i+1:]
	}
	data, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return "", err
	}
	if len(data) == 0 {
		return "", errors.New("empty image data")
	}
	path := fmt.Sprintf("%s/%s.png", sdImageDir, uuid.New().String())
	// 0644: other processes may need to read the image, but it is never executable.
	if err := os.WriteFile(path, data, 0o644); err != nil {
		return "", err
	}
	return path, nil
}

// TXT2IMGReq is the JSON request body sent to the WebUI txt2img endpoint
// (POST /sdapi/v1/txt2img, built in SdDraw.Txt2Img). No tag uses omitempty,
// so every field, including zero values, is always sent.
// The "default" values noted below are the ones set by getDefaultDataTXT2IMGReq.
type TXT2IMGReq struct {
	Prompt            string  `json:"prompt"`               // positive prompt
	NegativePrompt    string  `json:"negative_prompt"`      // negative prompt
	Steps             int     `json:"steps"`                // number of sampling steps; default 20
	Width             int     `json:"width"`                // output image width; default 512
	Height            int     `json:"height"`               // output image height; default 512
	SamplerName       string  `json:"sampler_name"`         // sampler name, e.g. "DPM++ SDE Karras" (the default)
	BatchSize         int     `json:"batch_size"`           // images generated per batch; default 1
	CfgScale          int     `json:"cfg_scale"`            // CFG scale (prompt adherence); default 7. NOTE(review): an int cannot hold fractional scales such as 7.5
	Seed              int64   `json:"seed"`                 // random seed; default -1 (presumably "pick a random seed" server-side — confirm)
	DenoisingStrength float64 `json:"denoising_strength"`   // denoising strength; default 0 (presumably only used by the hires-fix pass in txt2img — confirm)
	EnableHr          bool    `json:"enable_hr"`            // enable the high-resolution ("hires fix") pass; default false
	HrScale           int     `json:"hr_scale"`             // hires-fix upscale factor. NOTE(review): not set by getDefaultDataTXT2IMGReq, so 0 is sent; an int also cannot hold factors like 1.5
	HrUpscaler        string  `json:"hr_upscaler"`          // name of the hires-fix upscaler, e.g. "R-ESRGAN 4x+ Anime6B"
	HrSecondPassSteps int     `json:"hr_second_pass_steps"` // steps for the hires-fix second pass; 0 presumably means "reuse Steps" — confirm
	HrResizeX         int     `json:"hr_resize_x"`          // explicit hires-fix target width; 0 presumably means "derive from HrScale" — confirm
	HrResizeY         int     `json:"hr_resize_y"`          // explicit hires-fix target height; 0 presumably means "derive from HrScale" — confirm
}

// getDefaultDataTXT2IMGReq returns the baseline txt2img request that Txt2Img
// starts from before applying user input:
//   - generic quality tags as the positive prompt and a common negative prompt;
//   - 20 steps with the "DPM++ SDE Karras" sampler, CFG scale 7 and seed -1;
//   - 512x512 output with one image per batch.
//
// All other fields, including DenoisingStrength (0) and EnableHr (false), keep
// their zero values.
func getDefaultDataTXT2IMGReq() TXT2IMGReq {
	var req TXT2IMGReq

	// Longer positive-prompt variant kept for reference:
	// "masterpiece, best quality,Amazing,finely detail,Depth of field,extremely detailed CG unity 8k wallpaper,"
	req.Prompt = "masterpiece, best quality,Amazing,finely detail,"
	req.NegativePrompt = "(worst quality:1.25), (low quality:1.25), (lowres:1.1), (monochrome:1.1), (greyscale), multiple views, comic, sketch, (blurry:1.05),"

	req.Steps = 20
	req.SamplerName = "DPM++ SDE Karras"
	req.CfgScale = 7
	req.Seed = -1

	req.Width, req.Height = 512, 512
	req.BatchSize = 1

	return req
}
